【踩坑篇】BN
tensorflow 中实现BN的方式有多重:
简介
目前的BN的操作均基于2015年google提出的《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》,但是在实际操作中发现有很多需要特别注意的点
- 基于mini-batch进行训练,要保证训练和测试的数据的同分布
- 不同batch的分布的稳定性
公式:y=γ(x-μ)/σ+β
tf.nn.batch_normalization
相关的参数说明:
- x 输入数据
- mean 样本均值
- variance 样本标准差
- offset 样本的偏倚
- scale 样本的缩放
因为要考虑训练和测试数据的同质性,故在进行BN时操作是不同的
# 定义BN相关的占位符(给定维度size和权重分配decay)
# pop_mean和pop_vari为训练数据的整体情况的整合
gamma = tf.variable(tf.ones[size])
beta = tf.variable(tf.zeros[size])
pop_mean = tf.variable(tf.zeros[size], trainable=False)
pop_vari = tf.variable(tf.ones[size], trainable=False)
# 针对训练集,是按照batch的均值进行计算训练,故BN时为真实值+偏置后送入模型
batch_mean, batch_variance = tf.nn.moments(layer, [0])
train_mean = tf.assign(pop_mean, pop_mean*decay + batch_mean*(1-decay))
train_vari = tf.assign(pop_vari, pop_vari*decay + batch_vari*(1-decay))
bn_layer = tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma)
# 针对测试集
bn_layer = tf.nn.batch_normalization(layer, pop_mean, pop_vari, beta, gamma)
tf.layers.batch_normalization
相关参数说明:
- inputs 输入数据
- momentum 训练时整体和batch数据作用
- training 执行类型
因为要考虑训练和测试数据的同质性,故在进行BN时操作是不同的
# 训练数据
bn_layer = tf.layers.batch_normalization(layer, momentum=decay, traning=True)
# 测试数据
bn_layer = tf.layers.batch_normalization(layer, traning=False)